{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Running a Gemma 2-based agentic RAG with Ollama on Vertex AI and LangGraph\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Fserving%2Fvertex_ai_ollama_gemma2_rag_agent.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_ollama_gemma2_rag_agent.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "| Author(s) |  [Ivan Nardini](https://github.com/inardini) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "\n",
        "> [**Gemma 2**](https://ai.google.dev/gemma) is a new generation of open models developed by Google. It offers pre-trained and instruction-tuned variants in three sizes (2B, 9B, and 27B parameters), designed for high performance and efficiency on various hardware. Gemma 2 models are available through platforms like Google AI Studio, Kaggle, and Hugging Face.\n",
        "\n",
        "> [**Ollama**](https://github.com/ollama/ollama) is a tool for running open-source large language models (LLMs) locally. It simplifies LLM usage by bundling model weights, configuration, and data into a single package defined by a [`Modelfile`](https://github.com/ollama/ollama/blob/main/docs/modelfile.md). Ollama supports many models, such as Llama 2, Mistral, Code Llama, and Gemma, and runs on macOS, Linux, and Windows.\n",
        "\n",
        "> [**LangGraph**](https://python.langchain.com/en/latest/modules/graphs/langgraph.html) is a framework developed by LangChain for building applications with complex workflows, including agents and multi-agent systems. It offers precise control over application flow and state, supporting cyclical graphs and advanced state management.  LangGraph enhances LangChain's capabilities, providing more flexibility for agentic applications.\n",
        "\n",
        "> [**Google Vertex AI**](https://cloud.google.com/vertex-ai) is Google Cloud's unified machine learning (ML) platform.  It provides a comprehensive suite of tools for building, training, deploying, and managing ML models and AI applications, including large language models (LLMs). Vertex AI streamlines the entire ML workflow, from data management to prediction, and supports customization for specific business needs.\n",
        "\n",
        "This notebook shows how to run a Gemma 2-based agent with Ollama on Vertex AI and LangGraph.\n",
        "\n",
        "By the end of this notebook, you will learn how to:\n",
        "\n",
        "- Deploy Google Gemma 2 on Vertex AI using Ollama\n",
        "- Test the serving container locally using the Vertex AI `LocalModel` class\n",
        "- Implement a simple RAG agent application with Gemma 2 and Ollama using LangGraph"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Vertex AI SDK and other required packages\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet \"huggingface_hub\" \\\n",
        "    \"hf_transfer\" \\\n",
        "    \"google-cloud-aiplatform[prediction]\" \\\n",
        "    \"torch\" \\\n",
        "    \"etils\" \\\n",
        "    \"crcmod\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5Xep4W9lq-Z"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XRvKdaPDTznN"
      },
      "outputs": [],
      "source": [
        "import IPython\n",
        "\n",
        "app = IPython.Application.instance()\n",
        "app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbmM4z7FOBpM"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. In Colab or Colab Enterprise, you might see an error message that says \"Your session crashed for an unknown reason.\" This is expected. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NyKGtVQjgx13"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ee37e1544281"
      },
      "source": [
        "### Requirements"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "877cd3fb2dce"
      },
      "source": [
        "You will need to have the following IAM roles set:\n",
        "\n",
        "- Artifact Registry Administrator (`roles/artifactregistry.admin`)\n",
        "- Cloud Build Editor (`roles/cloudbuild.builds.editor`)\n",
        "- Vertex AI User (`roles/aiplatform.user`)\n",
        "- Service Account User (`roles/iam.serviceAccountUser`)\n",
        "- Service Usage Consumer (`roles/serviceusage.serviceUsageConsumer`)\n",
        "- Storage Admin (`roles/storage.admin`)\n",
        "\n",
        "For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).\n",
        "\n",
        "---\n",
        "\n",
        "You will also need to enable the following APIs (if not enabled already):\n",
        "\n",
        "- Artifact Registry API (artifactregistry.googleapis.com)\n",
        "- Cloud Build API (cloudbuild.googleapis.com)\n",
        "- Vertex AI API (aiplatform.googleapis.com)\n",
        "- Compute Engine API (compute.googleapis.com)\n",
        "\n",
        "For more information about API enablement, see [Enabling APIs](https://cloud.google.com/apis/docs/getting-started#enabling_apis)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f9af3e57f89a"
      },
      "source": [
        "### Authenticate your Hugging Face account\n",
        "\n",
        "As [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it) is a gated model, you need a Hugging Face Hub account and must accept Google's usage license for Gemma. Then generate a new user access token with read-only access so that the weights can be downloaded from the Hub.\n",
        "\n",
        "> Note that the user access token can only be generated via [the Hugging Face Hub UI](https://huggingface.co/settings/tokens/new), where you can either select read-only access to your account, or follow the recommendations and generate a fine-grained token with read-only access to [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4c31c7272804"
      },
      "source": [
        "The `huggingface_hub` library installed earlier provides the `interpreter_login` helper. Run the cell below and paste the token you generated; it is saved locally so that later cells can retrieve it safely with `huggingface_hub.get_token`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8d836e0210fe"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import interpreter_login\n",
        "\n",
        "interpreter_login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c71a4314c250"
      },
      "source": [
        "Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "import vertexai\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "BUCKET_NAME = \"[your-bucket-name]\"  # @param {type: \"string\", placeholder: \"[your-bucket-name]\", isTemplate: true}\n",
        "\n",
        "if not BUCKET_NAME or BUCKET_NAME == \"[your-bucket-name]\":\n",
        "    BUCKET_NAME = f\"{PROJECT_ID}-bucket\"\n",
        "\n",
        "BUCKET_URI = f\"gs://{BUCKET_NAME}\"\n",
        "\n",
        "! gsutil mb -p $PROJECT_ID -l $LOCATION $BUCKET_URI\n",
        "\n",
        "vertexai.init(project=PROJECT_ID, location=LOCATION, staging_bucket=BUCKET_URI)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BZGpwahJBkeo"
      },
      "source": [
        "### Set tutorial folder\n",
        "\n",
        "Define a folder for the tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XhFT7nQUBm23"
      },
      "outputs": [],
      "source": [
        "from etils import epath\n",
        "\n",
        "TUTORIAL_DIR = epath.Path(\"ollama_on_vertex_ai_tutorial\")\n",
        "BUILD_DIR = TUTORIAL_DIR / \"build\"\n",
        "MODELS_DIR = BUILD_DIR / \"ollama_models\"\n",
        "\n",
        "MODELS_DIR.mkdir(exist_ok=True, parents=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5303c05f7aa6"
      },
      "source": [
        "### Import libraries\n",
        "\n",
        "Import main libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6fc324893334"
      },
      "outputs": [],
      "source": [
        "import gc\n",
        "import json\n",
        "\n",
        "from google.cloud import aiplatform\n",
        "from google.cloud.aiplatform import Endpoint, Model\n",
        "from google.cloud.aiplatform.prediction import LocalModel\n",
        "from huggingface_hub import get_token, snapshot_download\n",
        "import torch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4-HLop9sq-BA"
      },
      "source": [
        "### Library settings"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cNF1FIlDrCtJ"
      },
      "outputs": [],
      "source": [
        "os.environ[\"HF_HUB_ENABLE_HF_TRANSFER\"] = \"1\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E4rjOknk5yvq"
      },
      "source": [
        "### Helpers\n",
        "\n",
        "Define some helpers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MOoaB5xf51BD"
      },
      "outputs": [],
      "source": [
        "def get_cuda_device_names():\n",
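        "    \"\"\"Return the indices of the available NVIDIA GPUs as strings, or None if CUDA is unavailable.\"\"\"\n",
        "    if not torch.cuda.is_available():\n",
        "        return None\n",
        "\n",
        "    return [str(i) for i in range(torch.cuda.device_count())]\n",
        "\n",
        "\n",
        "def empty_gpu_ram():\n",
        "    \"\"\"Run garbage collection and release cached GPU memory held by PyTorch.\"\"\"\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()"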
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "njDb6J4FQihU"
      },
      "source": [
        "## Deploy Gemma 2 using Ollama on Vertex AI Prediction\n",
        "\n",
        "To deploy Gemma 2 as an Ollama model on Vertex AI Prediction, a custom container with the Ollama server and the Gemma 2 model is required. You can use Cloud Build, a serverless CI/CD platform, to build the serving container image."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-7QAx-XzRUmW"
      },
      "source": [
        "### Download Gemma 2 from Hugging Face Hub\n",
        "\n",
        "Download `google-cloud-partnership/gemma-2-2b-it-lora-sql`, a LoRA adapter for Gemma 2 2B that lets Gemma 2 handle SQL-related user requests."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hws_NxS-RZQR"
      },
      "outputs": [],
      "source": [
        "base_model_id = \"google-cloud-partnership/gemma-2-2b-it-lora-sql\"\n",
        "model_dir = MODELS_DIR / \"gemma-2-2b-it-lora-sql\"\n",
        "\n",
        "ignore_patterns = [\".gitattributes\", \".gitkeep\", \"*.md\"]\n",
        "\n",
        "snapshot_download(\n",
        "    repo_id=base_model_id,\n",
        "    token=get_token(),\n",
        "    local_dir=model_dir,\n",
        "    ignore_patterns=ignore_patterns,\n",
        ")\n",
        "\n",
        "! rm -rf $model_dir/.cache"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IprOEAAN1sBQ"
      },
      "source": [
        "### Create Artifact Registry repository\n",
        "\n",
        "To store the container image, create a Docker repository in Google Cloud Artifact Registry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p5hXDtoYsCEB"
      },
      "outputs": [],
      "source": [
        "REPOSITORY_NAME = \"ollama-gemma-on-vertex\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z1ZBM9PDrbdM"
      },
      "outputs": [],
      "source": [
        "!gcloud artifacts repositories create $REPOSITORY_NAME \\\n",
        "      --repository-format=docker \\\n",
        "      --location=$LOCATION \\\n",
        "      --project=$PROJECT_ID"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IDMpuXEu2thu"
      },
      "source": [
        "### Create a Dockerfile\n",
        "\n",
        "Use the following Dockerfile to define the container's build steps. The Dockerfile installs Python and FastAPI, sets environment variables, copies the Ollama model files, creates the adapter model, exposes ports, and starts the Ollama server and a proxy server.\n",
        "\n",
        "> For simplicity, this example runs both Ollama and FastAPI in the same container."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vi9T53CScWdn"
      },
      "outputs": [],
      "source": [
        "dockerfile = \"\"\"\n",
        "# Start from the official Ollama image\n",
        "FROM ollama/ollama:0.5.5\n",
        "\n",
        "# Install Python and the FastAPI proxy dependencies (asyncio ships with Python)\n",
        "RUN apt-get update && \\\n",
        "    apt-get install -y python3 python3-pip curl && \\\n",
        "    pip3 install fastapi uvicorn httpx\n",
        "\n",
        "# Set build-time arguments for better flexibility\n",
        "ARG OLLAMA_PORT=8079\n",
        "ARG SERVING_PORT=8080\n",
        "\n",
        "# Set environment variables\n",
        "ENV OLLAMA_HOST=0.0.0.0:${OLLAMA_PORT} \\\n",
        "    OLLAMA_MODELS=/ollama_models \\\n",
        "    OLLAMA_KEEP_ALIVE=-1 \\\n",
        "    OLLAMA_DEBUG=false\n",
        "\n",
        "# Copy model files\n",
        "COPY ./ollama_models /ollama_models\n",
        "COPY gemma-2-2b-it-lora-sql.modelfile .\n",
        "\n",
        "# Expose ollama port\n",
        "EXPOSE ${OLLAMA_PORT}\n",
        "\n",
        "# Start a temporary Ollama server and register the adapter model at build time\n",
        "RUN ollama serve & \\\n",
        "    sleep 5 && ollama create gemma-2-2b-it-lora-sql-2b -f gemma-2-2b-it-lora-sql.modelfile\n",
        "\n",
        "# Expose port\n",
        "EXPOSE ${SERVING_PORT}\n",
        "\n",
        "# Copy the proxy server code and entrypoint script\n",
        "COPY main.py .\n",
        "COPY entrypoint.sh .\n",
        "\n",
        "# Make the entrypoint executable; it starts Ollama and then the proxy server\n",
        "RUN chmod +x ./entrypoint.sh\n",
        "ENTRYPOINT [\"./entrypoint.sh\"]\n",
        "\"\"\"\n",
        "\n",
        "with BUILD_DIR.joinpath(\"Dockerfile\").open(\"w\") as f:\n",
        "    f.write(dockerfile)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IPPxybcDXNvn"
      },
      "source": [
        "### Create Modelfile\n",
        "\n",
        "Define an Ollama Modelfile, the configuration file Ollama uses to create a model. This one starts from the `gemma2:2b` base model and applies the downloaded LoRA adapter on top of it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YbHAP5gxXQjM"
      },
      "outputs": [],
      "source": [
        "modelfile = \"\"\"FROM gemma2:2b\n",
        "ADAPTER ollama_models/gemma-2-2b-it-lora-sql\n",
        "\"\"\"\n",
        "\n",
        "with BUILD_DIR.joinpath(\"gemma-2-2b-it-lora-sql.modelfile\").open(\"w\") as f:\n",
        "    f.write(modelfile)"
      ]
    },
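    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you later want to bake default generation settings into the model, a Modelfile can also carry `PARAMETER` instructions. The hypothetical `render_modelfile` helper below is a minimal sketch: it reproduces the `FROM`/`ADAPTER` pair written above and appends optional `PARAMETER` lines. The `temperature` value in the usage comment is only an illustration, not a setting this tutorial relies on.\n",
        "\n",
        "```python\n",
        "def render_modelfile(base_model, adapter_path, parameters=None):\n",
        "    \"\"\"Render an Ollama Modelfile from a base model, a LoRA adapter, and optional parameters.\"\"\"\n",
        "    lines = [f\"FROM {base_model}\", f\"ADAPTER {adapter_path}\"]\n",
        "    # Each PARAMETER line sets a default option for models created from this file.\n",
        "    for name, value in (parameters or {}).items():\n",
        "        lines.append(f\"PARAMETER {name} {value}\")\n",
        "    return \"\\n\".join(lines) + \"\\n\"\n",
        "\n",
        "\n",
        "# Same content as the Modelfile written above:\n",
        "# render_modelfile(\"gemma2:2b\", \"ollama_models/gemma-2-2b-it-lora-sql\")\n",
        "# With an illustrative default temperature:\n",
        "# render_modelfile(\"gemma2:2b\", \"ollama_models/gemma-2-2b-it-lora-sql\", {\"temperature\": 0.2})\n",
        "```"
      ]
    },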
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZLUVwtidieQq"
      },
      "source": [
        "### Serve engine proxy\n",
        "\n",
        "This FastAPI application acts as a proxy between the Vertex AI endpoint and the Ollama server running in the same container.\n",
        "\n",
        "It receives prediction requests from Vertex AI, forwards them to Ollama, and returns the responses to Vertex AI in the expected format. The application also includes a health check route, request validation, error handling, and concurrent asynchronous calls to Ollama.\n",
        "\n",
        "> In this scenario, the FastAPI application only maps Ollama's `/api/generate` API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OyfS_j3fiiFv"
      },
      "outputs": [],
      "source": [
        "app_module = \"\"\"\n",
        "'''\n",
        "FastAPI proxy for Vertex AI Endpoint running Ollama.\n",
        "'''\n",
        "\n",
        "import os\n",
        "from typing import Any, Dict, List, Optional\n",
        "from pydantic import BaseModel, Field\n",
        "import httpx\n",
        "from fastapi import FastAPI, HTTPException, status\n",
        "from fastapi.responses import JSONResponse\n",
        "import asyncio\n",
        "\n",
        "\n",
        "# Request and response schemas\n",
        "class PredictionRequest(BaseModel):\n",
        "    '''Request model for predictions'''\n",
        "    instances: List[Dict] = Field(..., description=\"List of prediction instances\")\n",
        "\n",
        "class PredictionResponse(BaseModel):\n",
        "    '''Response model for predictions'''\n",
        "    predictions: List[str] = Field(..., description=\"List of model responses\")\n",
        "\n",
        "# Configuration\n",
        "class Config:\n",
        "    '''Application configuration.'''\n",
        "    HEALTH_ROUTE: str = os.environ.get('AIP_HEALTH_ROUTE', '/health')\n",
        "    PREDICT_ROUTE: str = os.environ.get('AIP_PREDICT_ROUTE', '/predict')\n",
        "    PORT: int = int(os.environ.get('AIP_HTTP_PORT', '8080'))\n",
        "    OLLAMA_URL: str = os.environ.get('OLLAMA_URL', 'http://localhost:8079')\n",
        "    MODEL_NAME: str = os.environ.get('MODEL_NAME', 'gemma-2-2b-it-lora-sql-2b')\n",
        "    TIMEOUT: int = int(os.environ.get('TIMEOUT_SECONDS', '30'))\n",
        "\n",
        "# Helper function\n",
        "async def ollama_generate(prompt: str, parameters: Dict[str, Any]) -> str:\n",
        "    '''\n",
        "    Make a prediction using the Ollama model.\n",
        "    '''\n",
        "    async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:\n",
        "        try:\n",
        "            response = await client.post(\n",
        "                f\"{Config.OLLAMA_URL}/api/generate\",\n",
        "                json={\n",
        "                    \"prompt\": prompt,\n",
        "                    \"stream\": False,\n",
        "                    \"options\": parameters,\n",
        "                    \"model\": Config.MODEL_NAME\n",
        "                }\n",
        "            )\n",
        "            response.raise_for_status()\n",
        "            return response.json()[\"response\"]\n",
        "\n",
        "        except httpx.HTTPError as e:\n",
        "            raise HTTPException(\n",
        "                status_code=status.HTTP_502_BAD_GATEWAY,\n",
        "                detail=f\"Error calling Ollama: {str(e)}\"\n",
        "            )\n",
        "# Application\n",
        "app = FastAPI(\n",
        "    title=\"Ollama Vertex AI Proxy\",\n",
        "    description=\"A proxy service to run Ollama models on Vertex AI\"\n",
        ")\n",
        "\n",
        "@app.get(\n",
        "    Config.HEALTH_ROUTE,\n",
        "    response_model=Dict[str, str],\n",
        "    description=\"Health check endpoint\",\n",
        ")\n",
        "async def health() -> Dict[str, str]:\n",
        "    '''Check if the service is healthy.'''\n",
        "    return {'status': 'healthy'}\n",
        "\n",
        "@app.post(\n",
        "    Config.PREDICT_ROUTE,\n",
        "    response_model=PredictionResponse,\n",
        "    description=\"Make predictions using the Ollama model\",\n",
        ")\n",
        "async def predict(request: PredictionRequest) -> PredictionResponse:\n",
        "    '''Process predictions using the Ollama model concurrently.'''\n",
        "\n",
        "    if not request.instances:\n",
        "        raise HTTPException(\n",
        "            status_code=status.HTTP_400_BAD_REQUEST,\n",
        "            detail=\"No instances provided in request\"\n",
        "        )\n",
        "\n",
        "    try:\n",
        "        # Process all prompts concurrently\n",
        "        tasks = []\n",
        "        for instance in request.instances:\n",
        "            prompt = instance.get('inputs', '')\n",
        "            parameters = instance.get('parameters', {})\n",
        "            tasks.append(ollama_generate(prompt, parameters))\n",
        "\n",
        "        # Wait for all requests to complete\n",
        "        predictions = await asyncio.gather(*tasks)\n",
        "        return PredictionResponse(predictions=predictions)\n",
        "\n",
        "    except HTTPException:\n",
        "        # Preserve upstream errors (for example, 502 from Ollama) instead of masking them as 500\n",
        "        raise\n",
        "    except Exception as e:\n",
        "        raise HTTPException(\n",
        "            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,\n",
        "            detail=f\"Error processing prediction: {str(e)}\"\n",
        "        )\n",
        "\n",
        "@app.exception_handler(HTTPException)\n",
        "async def http_exception_handler(request, exc):\n",
        "    '''Handle HTTP exceptions with a consistent format.'''\n",
        "    return JSONResponse(\n",
        "        status_code=exc.status_code,\n",
        "        content={\n",
        "            \"error\": {\n",
        "                \"code\": exc.status_code,\n",
        "                \"message\": exc.detail\n",
        "            }\n",
        "        }\n",
        "    )\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import uvicorn\n",
        "    uvicorn.run(\n",
        "        \"main:app\",\n",
        "        host=\"0.0.0.0\",\n",
        "        port=Config.PORT\n",
        "    )\n",
        "\"\"\"\n",
        "\n",
        "with BUILD_DIR.joinpath(\"main.py\").open(\"w\") as f:\n",
        "    f.write(app_module)"
      ]
    },
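    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the proxy's contract concrete, the sketch below isolates the two translations `main.py` performs as plain functions that run without a server: a Vertex AI instance (`inputs` plus optional `parameters`) becomes an Ollama `/api/generate` request body, and the generated texts are wrapped as Vertex AI `predictions`. The function names `to_ollama_payload` and `to_vertex_response` are illustrative and do not appear in `main.py`.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "\n",
        "def to_ollama_payload(instance, model_name=\"gemma-2-2b-it-lora-sql-2b\"):\n",
        "    \"\"\"Translate one Vertex AI instance into an Ollama /api/generate request body.\"\"\"\n",
        "    return {\n",
        "        \"prompt\": instance.get(\"inputs\", \"\"),\n",
        "        \"stream\": False,  # ask Ollama for one complete JSON response\n",
        "        \"options\": instance.get(\"parameters\", {}),  # e.g. temperature, num_predict\n",
        "        \"model\": model_name,\n",
        "    }\n",
        "\n",
        "\n",
        "def to_vertex_response(texts):\n",
        "    \"\"\"Wrap generated texts in the prediction response shape returned to Vertex AI.\"\"\"\n",
        "    return {\"predictions\": list(texts)}\n",
        "\n",
        "\n",
        "request = {\"instances\": [{\"inputs\": \"List all tables.\", \"parameters\": {\"temperature\": 0.1}}]}\n",
        "payloads = [to_ollama_payload(instance) for instance in request[\"instances\"]]\n",
        "print(json.dumps(payloads[0], indent=2))\n",
        "```"
      ]
    },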
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aIBHXaHMWlpN"
      },
      "source": [
        "### Entrypoint script\n",
        "\n",
        "Create an entrypoint script that starts both the Ollama server and the FastAPI application. The script launches Ollama in the background, polls it for up to 30 seconds until it responds, and then starts the FastAPI proxy as the container's main process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PR8esJG0WpGe"
      },
      "outputs": [],
      "source": [
        "entrypoint_script = \"\"\"#!/bin/bash\n",
        "\n",
        "# Enable error handling\n",
        "set -e\n",
        "\n",
        "# Function to log messages with timestamps\n",
        "log() {\n",
        "    echo \"[$(date +'%Y-%m-%d %H:%M:%S')] $1\"\n",
        "}\n",
        "\n",
        "# Function to check if Ollama is ready\n",
        "check_ollama() {\n",
        "    for i in {1..30}; do\n",
        "        if curl -s http://localhost:8079 >/dev/null; then\n",
        "            log \"✅ Ollama is ready!\"\n",
        "            return 0\n",
        "        fi\n",
        "        log \"⏳ Waiting for Ollama to start... ($i/30)\"\n",
        "        sleep 1\n",
        "    done\n",
        "    log \"❌ Ollama failed to start within 30 seconds\"\n",
        "    return 1\n",
        "}\n",
        "\n",
        "# Start Ollama in the background\n",
        "log \"🚀 Starting Ollama...\"\n",
        "ollama serve & sleep 5\n",
        "\n",
        "# Wait for Ollama to be ready\n",
        "check_ollama\n",
        "\n",
        "# Start the FastAPI serving application\n",
        "log \"🚀 Starting FastAPI serving application...\"\n",
        "exec python3 /main.py\n",
        "\"\"\"\n",
        "\n",
        "with BUILD_DIR.joinpath(\"entrypoint.sh\").open(\"w\") as f:\n",
        "    f.write(entrypoint_script)"
      ]
    },
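    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b7e3f1a29c04"
      },
      "source": [
        "The `check_ollama` function above polls Ollama once per second for up to 30 attempts, and the script exits if Ollama never responds. For illustration only, the following Python sketch reproduces the same retry-until-ready pattern with the standard library. The helper names `http_probe` and `wait_until_ready` are hypothetical and are not used by the container. The probe is passed in as a function so that you can try the loop without a running server."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c41d8e6a0f57"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import urllib.error\n",
        "import urllib.request\n",
        "from typing import Callable\n",
        "\n",
        "\n",
        "def http_probe(url: str, timeout: float = 2.0) -> bool:\n",
        "    \"\"\"Return True if the server at `url` answers an HTTP GET, like `curl -s` does.\"\"\"\n",
        "    try:\n",
        "        with urllib.request.urlopen(url, timeout=timeout):\n",
        "            return True\n",
        "    except urllib.error.HTTPError:\n",
        "        # The server answered, even if with an error status, so it is up\n",
        "        return True\n",
        "    except (urllib.error.URLError, OSError):\n",
        "        # Connection refused, DNS failure, or timeout: not ready yet\n",
        "        return False\n",
        "\n",
        "\n",
        "def wait_until_ready(\n",
        "    probe: Callable[[], bool],\n",
        "    attempts: int = 30,\n",
        "    delay_s: float = 1.0,\n",
        "    sleep: Callable[[float], None] = time.sleep,\n",
        ") -> bool:\n",
        "    \"\"\"Call `probe` up to `attempts` times, sleeping `delay_s` seconds between failures.\"\"\"\n",
        "    for attempt in range(1, attempts + 1):\n",
        "        if probe():\n",
        "            print(f\"Ready after {attempt} attempt(s)\")\n",
        "            return True\n",
        "        print(f\"Waiting for server... ({attempt}/{attempts})\")\n",
        "        sleep(delay_s)\n",
        "    print(f\"Server not ready after {attempts} attempts\")\n",
        "    return False\n",
        "\n",
        "\n",
        "# Example usage, assuming a server on port 8079 as in the entrypoint script:\n",
        "# wait_until_ready(lambda: http_probe(\"http://localhost:8079\"))"
      ]
    },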
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5RnaYx2p235W"
      },
      "source": [
        "### Build the container image with Cloud Build\n",
        "\n",
        "Use Cloud Build to build the container image.\n",
        "\n",
        "> The operation will take less than 5 minutes.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "27nddRobW-_Y"
      },
      "outputs": [],
      "source": [
        "SERVING_CONTAINER_IMAGE_URI = (\n",
        "    f\"{LOCATION}-docker.pkg.dev/{PROJECT_ID}/{REPOSITORY_NAME}/ollama-gemma-2-serve\"\n",
        ")\n",
        "\n",
        "! gcloud auth configure-docker $LOCATION-docker.pkg.dev --quiet\n",
        "! gcloud builds submit --tag $SERVING_CONTAINER_IMAGE_URI --project $PROJECT_ID --machine-type e2-highcpu-32 $BUILD_DIR"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xgcPjmMfXwAg"
      },
      "source": [
        "### (Optional) Testing Ollama container locally using serving container with Vertex AI LocalModel\n",
        "\n",
        "For debugging purpose, Vertex AI provides `LocalModel` class, accessible through the Vertex AI SDK for Python. This class allows you to build and deploy your model locally, simulating the Vertex AI environment. Using LocalModel involves creating a Docker image that encapsulates your custom predictor code and the associated handler.\n",
        "\n",
        "> **Important**: Running the LocalModel class requires a local Docker installation. This allows the model to be encapsulated within a container for consistent execution across different environments.\n",
        "\n",
        "> If you haven't already installed Docker Engine, please refer to the official installation guide: [Install Docker Engine](https://docs.docker.com/engine/install/). This documentation provides detailed instructions for various operating systems and will guide you through the installation process. Ensure Docker is running correctly before proceeding with the LocalModel examples.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Im5wedPQuMoY"
      },
      "source": [
        "#### Create a LocalModel instance\n",
        "\n",
        "Set up a local model by specifying the container image to use and the port it will listen on (8080).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VS9Z75d9uW3x"
      },
      "outputs": [],
      "source": [
        "local_ollama_gemma_model = LocalModel(\n",
        "    serving_container_image_uri=SERVING_CONTAINER_IMAGE_URI,\n",
        "    serving_container_ports=[8080],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3bxcGRl8uas-"
      },
      "source": [
        "#### Create a LocalEndpoint instance\n",
        "\n",
        "Deploy the model to a local endpoint for serving. The `gpu_device_ids` sets available GPUs if present.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hfbRbRHcuee3"
      },
      "outputs": [],
      "source": [
        "local_ollama_gemma_endpoint = local_ollama_gemma_model.deploy_to_local_endpoint(\n",
        "    gpu_device_ids=get_cuda_device_names()\n",
        ")\n",
        "\n",
        "local_ollama_gemma_endpoint.serve()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XmnfGmT8sQ-k"
      },
      "source": [
        "#### Monitoring Your Containerized Deployment\n",
        "\n",
        "To keep track of your container's deployment progress and identify any potential issues, you can use the following Docker commands within your terminal:\n",
        "\n",
        "1. **List all containers:** `docker container ls -a` displays a list of all running and stopped containers. Locate the container associated with your local endpoint and copy its ID.  This ID is essential for the next step.\n",
        "\n",
        "2. **Stream container logs:** `docker logs --follow <CONTAINER_ID>`  provides a real-time stream of your container's logs. Replace `<CONTAINER_ID>` with the ID you copied in the previous step. Monitoring these logs allows you to observe the deployment process, identify any errors or warnings, and understand the container's overall health."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7EcVi9Y_ujFj"
      },
      "source": [
        "#### Generate predictions\n",
        "\n",
        "Send a prediction request to a local Vertex AI endpoint.\n",
        "\n",
        "You convert the request data into a JSON string, send it to the endpoint, and then print the predictions from the JSON response.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CvkpSHzymXEz"
      },
      "outputs": [],
      "source": [
        "prediction_request = {\n",
        "    \"instances\": [\n",
        "        {\n",
        "            \"inputs\": \"How to run a select all query\",\n",
        "            \"parameters\": {\n",
        "                \"temperature\": 1.0,\n",
        "            },\n",
        "        },\n",
        "    ]\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Iso7dA7BzQiy"
      },
      "outputs": [],
      "source": [
        "vertex_prediction_request = json.dumps(prediction_request)\n",
        "vertex_prediction_response = local_ollama_gemma_endpoint.predict(\n",
        "    request=vertex_prediction_request, headers={\"Content-Type\": \"application/json\"}\n",
        ")\n",
        "print(vertex_prediction_response.json()[\"predictions\"])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3aa8a5163e2e"
      },
      "outputs": [],
      "source": [
        "vertex_prediction_response"
      ]
    },
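    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d92a6c3e18f0"
      },
      "source": [
        "The request above uses the Vertex AI online prediction format: the request body holds an `instances` list, and the response body holds a `predictions` list. The `inputs` and `parameters` fields inside each instance are the fields this notebook sends to the serving application. The following sketch shows the same round trip with the standard `json` module, without calling any endpoint. The helper names `build_prediction_request` and `parse_predictions` are hypothetical, and the sample response body is made up for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e5b07f4d2a19"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "from typing import Any\n",
        "\n",
        "\n",
        "def build_prediction_request(prompt: str, **parameters: Any) -> str:\n",
        "    \"\"\"Serialize one prompt into a JSON body with a Vertex AI-style `instances` list.\"\"\"\n",
        "    body = {\"instances\": [{\"inputs\": prompt, \"parameters\": parameters}]}\n",
        "    return json.dumps(body)\n",
        "\n",
        "\n",
        "def parse_predictions(response_body: str) -> list[Any]:\n",
        "    \"\"\"Extract the `predictions` list from a Vertex AI-style JSON response body.\"\"\"\n",
        "    payload = json.loads(response_body)\n",
        "    if \"predictions\" not in payload:\n",
        "        raise ValueError(f\"Unexpected response payload: {payload}\")\n",
        "    return payload[\"predictions\"]\n",
        "\n",
        "\n",
        "request_body = build_prediction_request(\"How to run a select all query\", temperature=1.0)\n",
        "print(request_body)\n",
        "\n",
        "# A made-up response body, for illustration only\n",
        "sample_response = '{\"predictions\": [\"SELECT * FROM my_table;\"]}'\n",
        "print(parse_predictions(sample_response))"
      ]
    },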
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ecK96GglLTEb"
      },
      "source": [
        "### Register Ollama model on Vertex AI\n",
        "\n",
        "To serve Gemma 2 with Ollama on Vertex AI, import the model on Vertex AI Model Registry, a central repository where you can manage the lifecycle of your ML models on Vertex AI, using the `aiplatform.Model.upload` method.\n",
        "\n",
        "Some of the main arguments of the `aiplatform.Model.upload` are:\n",
        "\n",
        "- `display_name`: The name shown in the Vertex AI Model Registry.\n",
        "- `serving_container_image_uri`: The location of the Ollama container.\n",
        "- (Optional) `serving_container_ports`: The port where the Vertex AI endpoint will be exposed (default 8080).\n",
        "\n",
        "For more information on the supported `aiplatform.Model.upload` arguments, check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_upload)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ef487ce082f3"
      },
      "outputs": [],
      "source": [
        "model = Model.upload(\n",
        "    display_name=\"google--gemma-2-2b-it-lora-sql-ollama\",\n",
        "    serving_container_image_uri=SERVING_CONTAINER_IMAGE_URI,\n",
        "    serving_container_ports=[8080],\n",
        ")\n",
        "model.wait()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c427c0a87016"
      },
      "source": [
        "### Deploy Ollama model on Vertex AI\n",
        "\n",
        "After the model is registered on Vertex AI, deploy the model to a Vertex AI endpoint using the `aiplatform.Model.deploy` method.\n",
        "\n",
        "Some of the main arguments of the `aiplatform.Model.upload` are:\n",
        "\n",
        "* (optional) **`endpoint`** : Set an endpoint for model deployment.\n",
        "* **`machine_type, accelerator_type, accelerator_count`** : Define the deployment instance and accelerator configuration.\n",
        "\n",
        "For more information on the supported `aiplatform.Model.deploy` arguments, you can check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_deploy)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JrcXG5KcozFm"
      },
      "source": [
 Note that the model deployment">
        "> Deploying the model on Vertex AI can take around 15 to 25 minutes. Most of that time is spent allocating resources and setting up networking and security."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e777427015db"
      },
      "outputs": [],
      "source": [
        "endpoint = Endpoint.create(\n",
        "    display_name=\"google--gemma-2-2b-it-lora-sql-ollama-endpoint\"\n",
        ")\n",
        "\n",
        "deployed_model = model.deploy(\n",
        "    endpoint=endpoint,\n",
        "    machine_type=\"g2-standard-4\",\n",
        "    accelerator_type=\"NVIDIA_L4\",\n",
        "    accelerator_count=1,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9b3fd1898241"
      },
      "source": [
        "### Online predictions on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2aa9cee03bd0"
      },
      "source": [
        "Once the model is deployed on Vertex AI, run the online predictions using the `aiplatform.Endpoint.predict` method, which will send the requests to the running endpoint in the `/predict` route specified within the container following Vertex AI I/O payload formatting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gOrBkKoy-_2d"
      },
      "outputs": [],
      "source": [
        "output = deployed_model.predict(\n",
        "    instances=[\n",
        "        {\n",
        "            \"inputs\": \"How to run a select all query\",\n",
        "            \"parameters\": {\n",
        "                \"temperature\": 1.0,\n",
        "            },\n",
        "        },\n",
        "    ]\n",
        ")\n",
        "predictions = output.predictions\n",
        "print(predictions[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3u7fALoY_S50"
      },
      "source": [
        "## Build an agentic RAG application using Ollama model on Vertex AI with LangGraph\n",
        "\n",
        "After deployed the Ollama model on Vertex AI, consume the model to build an agentic RAG application using Ollama model on Vertex AI with LangGraph."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dxG6zfZ9TYa-"
      },
      "source": [
        "### Install additional libraries\n",
        "\n",
        "Install langgraph libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zQd0iPF7Tb8R"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet \"langchain-community\" \\\n",
        "    \"langchainhub\" \\\n",
        "    \"langchain_google_vertexai\" \\\n",
        "    \"langgraph\" \\\n",
        "    \"faiss-gpu\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y_0aWPbuOHyy"
      },
      "source": [
        "### Import additional libraries\n",
        "\n",
        "Import additional libraries to build the agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NiFX3teSNzZ-"
      },
      "outputs": [],
      "source": [
        "from typing import Any, TypedDict\n",
        "\n",
        "from IPython.display import Image, display\n",
        "from langchain_community.document_loaders import WebBaseLoader\n",
        "from langchain_community.vectorstores import FAISS\n",
        "from langchain_core.messages import AIMessage, BaseMessage, HumanMessage\n",
        "from langchain_core.runnables.graph import MermaidDrawMethod\n",
        "from langchain_google_vertexai.embeddings import VertexAIEmbeddings\n",
        "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
        "from langgraph.graph import END, StateGraph"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mx4LDoATRtXN"
      },
      "source": [
        "### Define some helpers\n",
        "\n",
        "Define a `CustomVertexAIModel` to handle all the endpoint formatting and parameter management in a way to make the model compatible with a LangGraph agentic workflow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V9-1wIkXR0bA"
      },
      "outputs": [],
      "source": [
        "class CustomVertexAIModel:\n",
        "    \"\"\"\n",
        "    A simple wrapper for Vertex AI Endpoints.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        project: str,\n",
        "        location: str,\n",
        "        endpoint_id: str,\n",
        "        **model_params: Any,\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Initialize the Model Garden client.\n",
        "\n",
        "        Args:\n",
        "            project: Google Cloud project ID\n",
        "            location: Model location (e.g., \"us-central1\")\n",
        "            endpoint_id: Vertex AI endpoint ID\n",
        "            **model_params: Model parameters (temperature, max_tokens, etc.)\n",
        "        \"\"\"\n",
        "        self.endpoint = aiplatform.Endpoint(\n",
        "            endpoint_name=f\"projects/{project}/locations/{location}/endpoints/{endpoint_id}\"\n",
        "        )\n",
        "        self.model_params = model_params\n",
        "\n",
        "    def invoke(\n",
        "        self,\n",
        "        prompt: str,\n",
        "        **kwargs: Any,\n",
        "    ) -> str:\n",
        "        \"\"\"\n",
        "        Invoke the model with a prompt and optional parameter overrides.\n",
        "\n",
        "        Args:\n",
        "            prompt: The input text prompt\n",
        "            **kwargs: Optional parameter overrides for this specific call\n",
        "\n",
        "        Returns:\n",
        "            The model's response as a string\n",
        "        \"\"\"\n",
        "        # Merge default parameters with any call-specific overrides\n",
        "        parameters = {**self.model_params, **kwargs}\n",
        "\n",
        "        instance = {\"inputs\": prompt, \"parameters\": parameters}\n",
        "\n",
        "        response = self.endpoint.predict([instance])\n",
        "        return response.predictions[0]"
      ]
    },
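    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f13c9a8e5b62"
      },
      "source": [
        "In `invoke`, the call-specific keyword arguments are unpacked after the defaults in `{**self.model_params, **kwargs}`, so they override the values passed to the constructor. One way to check this offline is to inject the endpoint object instead of creating it inside the wrapper. The following sketch is a hypothetical variant, `InjectedEndpointModel`, exercised with an `EchoEndpoint` stub that records each request. Neither class is part of the Vertex AI SDK, and `max_tokens` is only an example parameter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a8d4e2f7c036"
      },
      "outputs": [],
      "source": [
        "from types import SimpleNamespace\n",
        "from typing import Any, Protocol\n",
        "\n",
        "\n",
        "class PredictsInstances(Protocol):\n",
        "    \"\"\"Any object with a `predict(instances)` method, such as `aiplatform.Endpoint`.\"\"\"\n",
        "\n",
        "    def predict(self, instances: list[dict[str, Any]]) -> Any: ...\n",
        "\n",
        "\n",
        "class InjectedEndpointModel:\n",
        "    \"\"\"Hypothetical variant of `CustomVertexAIModel` that receives its endpoint.\"\"\"\n",
        "\n",
        "    def __init__(self, endpoint: PredictsInstances, **model_params: Any) -> None:\n",
        "        self.endpoint = endpoint\n",
        "        self.model_params = model_params\n",
        "\n",
        "    def invoke(self, prompt: str, **kwargs: Any) -> str:\n",
        "        # Call-specific kwargs are unpacked last, so they override the defaults\n",
        "        parameters = {**self.model_params, **kwargs}\n",
        "        response = self.endpoint.predict([{\"inputs\": prompt, \"parameters\": parameters}])\n",
        "        return response.predictions[0]\n",
        "\n",
        "\n",
        "class EchoEndpoint:\n",
        "    \"\"\"Hypothetical stub that records each request and echoes the prompt back.\"\"\"\n",
        "\n",
        "    def __init__(self) -> None:\n",
        "        self.requests: list[list[dict[str, Any]]] = []\n",
        "\n",
        "    def predict(self, instances: list[dict[str, Any]]) -> SimpleNamespace:\n",
        "        self.requests.append(instances)\n",
        "        # Mimic the `.predictions` attribute of a real prediction response\n",
        "        return SimpleNamespace(predictions=[f\"echo: {instances[0]['inputs']}\"])\n",
        "\n",
        "\n",
        "echo_endpoint = EchoEndpoint()\n",
        "echo_llm = InjectedEndpointModel(echo_endpoint, temperature=1.0, max_tokens=256)\n",
        "print(echo_llm.invoke(\"How to run a select all query\", temperature=0.2))\n",
        "print(echo_endpoint.requests[-1][0][\"parameters\"])"
      ]
    },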
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mdNXjMN6y7zg"
      },
      "source": [
        "### Build the LangGraph agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pyQ7xAFFTTqO"
      },
      "source": [
        "#### Initialize Vertex AI components\n",
        "\n",
        "Initialize the Google's embedding model and the LLM model hosted on the Vertex AI Endpoint.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LloVRJuNTbLP"
      },
      "outputs": [],
      "source": [
        "embeddings = VertexAIEmbeddings(model_name=\"text-embedding-005\", project=PROJECT_ID)\n",
        "\n",
        "llm = CustomVertexAIModel(\n",
        "    endpoint_id=endpoint.name,\n",
        "    temperature=1.0,\n",
        "    project=PROJECT_ID,\n",
        "    location=LOCATION,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y9HSN0CETeTA"
      },
      "source": [
        "#### Define agent components\n",
        "\n",
        "According to LangGraph documentation, define an agent state and the following functions to build the agent:\n",
        "\n",
        "1. `create_vectorstore_from_urls`: Loads web pages, splits them into chunks, and creates a searchable vector database using FAISS embeddings.\n",
        "2. `retrieve`: Finds the 3 most similar document chunks from the vector store based on the user's query and adds them to the state context.\n",
        "3. `generate_response`: Takes the retrieved context and query, sends them to the LLM for processing, and updates the state with the response and conversation history.\n",
        "4. `should_rewrite`: Checks if the generated response is in proper SQL format by looking for SQL keywords.\n",
        "5. `rewrite_response`: Asks the LLM to reformat the response into a proper SQL query with comments and proper syntax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qDnsnmAiR5GP"
      },
      "outputs": [],
      "source": [
        "class AgentState(TypedDict):\n",
        "    query: str\n",
        "    messages: list[BaseMessage]\n",
        "    context: str\n",
        "    response: str\n",
        "    chat_history: list[BaseMessage]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tLDKJ9J-Tifw"
      },
      "outputs": [],
      "source": [
        "def create_vectorstore_from_urls(urls: list[str]) -> FAISS:\n",
        "    \"\"\"Create a FAISS vectorstore from webpage contents\"\"\"\n",
        "    loader = WebBaseLoader(urls)\n",
        "    documents = loader.load()\n",
        "    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
        "    splits = text_splitter.split_documents(documents)\n",
        "    return FAISS.from_documents(splits, embeddings)\n",
        "\n",
        "\n",
        "def retrieve(state: AgentState, vectorstore: FAISS) -> AgentState:\n",
        "    \"\"\"Retrieve relevant documents\"\"\"\n",
        "    docs = vectorstore.similarity_search(state[\"query\"], k=3)\n",
        "    state[\"context\"] = \"\\n\".join(doc.page_content for doc in docs)\n",
        "    return state\n",
        "\n",
        "\n",
        "def generate_response(state: AgentState) -> AgentState:\n",
        "    \"\"\"Generate response using the Model Garden LLM\"\"\"\n",
        "    prompt = f\"\"\"Context: {state[\"context\"]}\n",
        "                 Question: {state[\"query\"]}\n",
        "                 Please provide a helpful response based on the context above.\"\"\"\n",
        "\n",
        "    response = llm.invoke(prompt)\n",
        "\n",
        "    state[\"response\"] = response\n",
        "    state[\"messages\"].append(AIMessage(content=response))\n",
        "    state[\"chat_history\"].extend(\n",
        "        [HumanMessage(content=state[\"query\"]), AIMessage(content=response)]\n",
        "    )\n",
        "    return state\n",
        "\n",
        "\n",
        "def should_rewrite(state: AgentState) -> AgentState:\n",
        "    \"\"\"Decide if the response needs rewriting\"\"\"\n",
        "    # First check if query is SQL-related\n",
        "    sql_keywords = [\n",
        "        \"sql\",\n",
        "        \"query\",\n",
        "        \"select\",\n",
        "        \"table\",\n",
        "        \"database\",\n",
        "        \"bigquery\",\n",
        "        \"join\",\n",
        "        \"where\",\n",
        "    ]\n",
        "    query_is_sql = any(keyword in state[\"query\"].lower() for keyword in sql_keywords)\n",
        "\n",
        "    # Only check SQL formatting if the query was SQL-related\n",
        "    if query_is_sql:\n",
        "        response = state[\"response\"].lower()\n",
        "        needs_rewrite = (\n",
        "            not response.strip().startswith(\"select\")\n",
        "            and not response.strip().startswith(\"with\")\n",
        "            and not response.strip().startswith(\"create\")\n",
        "            and not response.strip().startswith(\"/*\")\n",
        "            and \"select\" not in response[:100]\n",
        "        )\n",
        "        state[\"next\"] = \"rewrite\" if needs_rewrite else \"end\"\n",
        "    else:\n",
        "        state[\"next\"] = \"end\"\n",
        "\n",
        "    return state\n",
        "\n",
        "\n",
        "def rewrite_response(state: AgentState) -> AgentState:\n",
        "    \"\"\"\n",
        "    Rewrite the response to ensure it's in proper SQL format\n",
        "    \"\"\"\n",
        "    prompt = f\"\"\"\n",
        "    Original question: {state[\"query\"]}\n",
        "    Previous response: {state[\"response\"]}\n",
        "\n",
        "    Rewrite the above as a proper SQL query following these rules:\n",
        "    - Start with SQL keywords (SELECT, WITH, CREATE, etc.)\n",
        "    - Include comments explaining the logic\n",
        "    - Format the query properly\n",
        "    - Use BigQuery SQL syntax\n",
        "    \"\"\"\n",
        "\n",
        "    new_response = llm.invoke(prompt)\n",
        "    state[\"response\"] = new_response\n",
        "    state[\"messages\"][-1] = AIMessage(content=new_response)  # Replace last message\n",
        "    state[\"chat_history\"][-1] = AIMessage(\n",
        "        content=new_response\n",
        "    )  # Replace last history item\n",
        "\n",
        "    return state"
      ]
    },
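    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9e6b1d4c7a20"
      },
      "source": [
        "The following pure-Python sketch illustrates the two ideas behind `create_vectorstore_from_urls` and `retrieve`: splitting text into overlapping chunks, and returning the `k` stored vectors closest to a query vector. It is a simplification, not the LangChain or FAISS implementation: `RecursiveCharacterTextSplitter` also tries to break on separators such as paragraphs and newlines, and FAISS uses optimized indexes. The helper names are hypothetical, the toy 2-D vectors stand in for real embeddings, and the ranking uses Euclidean (L2) distance, assumed here to match the default of LangChain's FAISS vector store."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7c2f5a9e3b81"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def sliding_window_chunks(\n",
        "    text: str, chunk_size: int = 1000, overlap: int = 200\n",
        ") -> list[str]:\n",
        "    \"\"\"Split `text` into fixed-size character windows that overlap by `overlap` characters.\"\"\"\n",
        "    if not 0 <= overlap < chunk_size:\n",
        "        raise ValueError(\"overlap must be at least 0 and smaller than chunk_size\")\n",
        "    if not text:\n",
        "        return []\n",
        "    step = chunk_size - overlap\n",
        "    # Skip start positions whose window would only repeat text the previous window covers\n",
        "    starts = range(0, max(len(text) - overlap, 1), step)\n",
        "    return [text[i : i + chunk_size] for i in starts]\n",
        "\n",
        "\n",
        "def top_k_l2(query: list[float], vectors: list[list[float]], k: int = 3) -> list[int]:\n",
        "    \"\"\"Return the indices of the `k` vectors closest to `query` by Euclidean distance.\"\"\"\n",
        "    ranked = sorted(range(len(vectors)), key=lambda i: math.dist(query, vectors[i]))\n",
        "    return ranked[:k]\n",
        "\n",
        "\n",
        "chunks = sliding_window_chunks(\"x\" * 2500)\n",
        "print([len(chunk) for chunk in chunks])\n",
        "\n",
        "toy_vectors = [[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [0.5, 0.5]]\n",
        "print(top_k_l2([0.0, 0.0], toy_vectors, k=3))"
      ]
    },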
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FEzPObd4mV-4"
      },
      "source": [
        "#### Assemble the agent\n",
        "\n",
        "Create a simple agentic RAG that first builds a searchable database from URLs, then sets up a sequence of steps (retrieve → generate → check for rewrite → either end or rewrite and loop back) to process queries and generate SQL responses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EVysTj9BmVRA"
      },
      "outputs": [],
      "source": [
        "def create_rag_agent(urls: list[str]) -> Any:\n",
        "    \"\"\"Create the RAG agent workflow\"\"\"\n",
        "    vectorstore = create_vectorstore_from_urls(urls)\n",
        "\n",
        "    workflow = StateGraph(AgentState)\n",
        "\n",
        "    # Add nodes\n",
        "    workflow.add_node(\"retrieve\", lambda s: retrieve(s, vectorstore))\n",
        "    workflow.add_node(\"generate\", generate_response)\n",
        "    workflow.add_node(\"should_rewrite\", should_rewrite)\n",
        "    workflow.add_node(\"rewrite\", rewrite_response)\n",
        "\n",
        "    # Add edges\n",
        "    workflow.add_edge(\"retrieve\", \"generate\")\n",
        "    workflow.add_edge(\"generate\", \"should_rewrite\")\n",
        "\n",
        "    # Add conditional edges\n",
        "    workflow.add_conditional_edges(\n",
        "        \"should_rewrite\", lambda x: x[\"next\"], {\"rewrite\": \"rewrite\", \"end\": END}\n",
        "    )\n",
        "\n",
        "    workflow.add_edge(\"rewrite\", \"retrieve\")\n",
        "\n",
        "    workflow.set_entry_point(\"retrieve\")\n",
        "\n",
        "    return workflow.compile()"
      ]
    },
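    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5a8e3c1f9d47"
      },
      "source": [
        "A compiled LangGraph workflow runs one node at a time, then follows either a fixed edge or, after `should_rewrite`, a conditional edge that maps the value stored in `state[\"next\"]` to the next node. The following plain-Python sketch is a simplified model of that control flow, not LangGraph's implementation. `run_state_machine` and `END_NODE` are hypothetical names, and the `max_steps` cap plays a role similar to LangGraph's recursion limit, which stops a graph that keeps looping without reaching the end."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3f9b6d2a8e15"
      },
      "outputs": [],
      "source": [
        "from typing import Any, Callable\n",
        "\n",
        "END_NODE = \"__end__\"  # Hypothetical marker meaning \"stop here\" in this sketch\n",
        "\n",
        "State = dict[str, Any]\n",
        "\n",
        "\n",
        "def run_state_machine(\n",
        "    nodes: dict[str, Callable[[State], State]],\n",
        "    edges: dict[str, str],\n",
        "    router: str,\n",
        "    route_map: dict[str, str],\n",
        "    entry: str,\n",
        "    state: State,\n",
        "    max_steps: int = 25,\n",
        ") -> tuple[State, list[str]]:\n",
        "    \"\"\"Run nodes from `entry` until END_NODE; return the final state and the visit order.\"\"\"\n",
        "    current, trace = entry, []\n",
        "    while current != END_NODE:\n",
        "        if len(trace) >= max_steps:\n",
        "            raise RuntimeError(f\"END_NODE not reached within {max_steps} steps: {trace}\")\n",
        "        trace.append(current)\n",
        "        state = nodes[current](state)\n",
        "        if current == router:\n",
        "            # Conditional edge: translate the routing decision into a node name\n",
        "            current = route_map[state[\"next\"]]\n",
        "        else:\n",
        "            # Fixed edge\n",
        "            current = edges[current]\n",
        "    return state, trace"
      ]
    },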
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PvFPaW4H0X5o"
      },
      "source": [
        "#### Initialize the agent\n",
        "\n",
        "Initialize the agent by passing the BigQuery documentation to ground the agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cptGsg_mmeCP"
      },
      "outputs": [],
      "source": [
        "urls = [\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/dml-syntax\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/arrays\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate_functions\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/subqueries\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/joins\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/operators\",\n",
        "    \"https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference\",\n",
        "]\n",
        "\n",
        "\n",
        "agent = create_rag_agent(urls)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xSY4cq1N0c6b"
      },
      "source": [
        "#### Visualize the agent\n",
        "\n",
        "Plot the agentic workflow.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FklWq97Rmfdl"
      },
      "outputs": [],
      "source": [
        "display(\n",
        "    Image(\n",
        "        agent.get_graph().draw_mermaid_png(\n",
        "            draw_method=MermaidDrawMethod.API,\n",
        "        )\n",
        "    )\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RCH92w2cnB9B"
      },
      "source": [
        "### Query the agent\n",
        "\n",
        "Use the agent to answer SQL code generation user requests."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tk4jB1qxUv-l"
      },
      "outputs": [],
      "source": [
        "def query_rag_agent(agent, query: str) -> dict[str, Any]:\n",
        "    \"\"\"Query the RAG agent\"\"\"\n",
        "    state = {\n",
        "        \"query\": query,\n",
        "        \"messages\": [],\n",
        "        \"context\": \"\",\n",
        "        \"response\": \"\",\n",
        "        \"chat_history\": [],\n",
        "        \"next\": None,  # Add next step field\n",
        "    }\n",
        "    return agent.invoke(state)\n",
        "\n",
        "\n",
        "questions = [\n",
        "    \"Write a SQL query to calculate the total sales per month and include a 3-month moving average\",\n",
        "    \"Create a query that finds the top 5 customers by revenue in each region, including their total spend and number of orders\",\n",
        "    \"Write a SQL query to analyze user engagement: calculate daily active users (DAU), weekly active users (WAU), and the DAU/WAU ratio for the last 30 days\",\n",
        "]\n",
        "\n",
        "# Test each question\n",
        "for question in questions:\n",
        "    result = query_rag_agent(agent, question)\n",
        "    print(f\"\\nQuestion: {question}\")\n",
        "    print(\"\\nResponse:\")\n",
        "    print(result[\"response\"])\n",
        "    print(\"\\n\" + \"=\" * 80)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f6f17f9aff65"
      },
      "source": [
        "## Cleaning up"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KFOXWLLTvMvJ"
      },
      "outputs": [],
      "source": [
        "delete_endpoint = False\n",
        "delete_model = False\n",
        "delete_artifact_registry = False\n",
        "delete_tutorial_folder = False\n",
        "\n",
        "if delete_endpoint:\n",
        "    deployed_model.undeploy_all()\n",
        "    deployed_model.delete()\n",
        "\n",
        "if delete_model:\n",
        "    delete_model.delete()\n",
        "\n",
        "if delete_artifact_registry:\n",
        "    ! gcloud artifacts repositories delete $REPOSITORY_NAME \\\n",
        "          --repository-format=docker \\\n",
        "          --location=$LOCATION \\\n",
        "          --project=$PROJECT_ID\n",
        "\n",
        "if delete_tutorial_folder:\n",
        "    import shutil\n",
        "\n",
        "    shutil.rmtree(tutorial_folder)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "vertex_ai_ollama_gemma2_rag_agent.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
